"""Mechanisms for image reconstruction from parameter gradients."""

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from collections import defaultdict, OrderedDict
from inversefed.nn import MetaMonkey
from .metrics import total_variation as TV
from .metrics import  aux_patch_loss as RP
from .metrics import  april_loss as AP
from .metrics import  group_r
from copy import deepcopy
import nevergrad as ng
import numpy as np
import inversefed.porting as porting
import torch.nn.functional as F

import time

# Lookup from dataset identifier to an integer size -- presumably the image
# side length in pixels (TODO confirm; not referenced in this part of the file).
imsize_dict = {
    'ImageNet': 224,
    'I128': 128,
    'I64': 64,
    'I32': 32,
    'I256': 256,
    'CIFAR10': 32,
    'CIFAR100': 32,
    'FFHQ': 512,
    'CA256': 256,
    'CA128': 128,
    'CA64': 64,
    'CA32': 32,
    'PERM64': 64,
    'PERM32': 32,
    'CelebA': 64,
}

# A progress line is printed whenever ``iteration % save_interval == 0``.
save_interval = 100
# 1-based iteration at which the group mean over all restarts is first built.
construct_group_mean_at = 500
# Once enabled, the group mean is rebuilt every this many iterations.
construct_gm_every = 100
# From this iteration on the group mean is produced by ``group_r`` instead of a plain average.
construct_gm_switch = 1_000_000
''''''
# Default reconstruction options. _validate_config copies every entry that is
# missing from a user config out of this dict and rejects keys not listed here.
DEFAULT_CONFIG = {
    'signed': False,
    'cost_fn': 'sim',
    'indices': 'def',
    'weights': 'equal',
    'lr': 0.1,                # learning rate handed to the Adam optimizer
    'optim': 'adam',          # 'adam', 'sgd' or 'LBFGS'
    'restarts': 1,            # number of independent trials
    'max_iterations': 4800,
    'total_variation': 1e-1,
    'bn_stat': 1e-1,          # > 0 attaches statistics hooks to BatchNorm2d layers
    'image_norm': 1e-1,
    'z_norm': 0,
    'group_lazy': 1e-1,       # > 0 enables construction of the group mean
    'init': 'randn',
    'lr_decay': True,         # attach a MultiStepLR schedule to each optimizer
    'r_patch': 0,
    'r_april': 0,
    'dataset': 'CIFAR10',

    # Generative-prior options.
    'generative_model': '',   # empty string: no generator is loaded
    'gen_dataset': '',
    'giml': False,
    'gias': False,
    'gias_lr': 0.1,           # Adam learning rate for the generator-parameter search
    'gias_iterations': 0,
}

def _validate_config(config):
    """Complete ``config`` with defaults and reject keys DEFAULT_CONFIG does not know.

    ``config`` is modified in place: every DEFAULT_CONFIG key that is absent
    from it, or mapped to ``None``, receives the default value.

    Returns:
        The same ``config`` object that was passed in.

    Raises:
        ValueError: If ``config`` contains a key for which DEFAULT_CONFIG has
            no (non-``None``) entry.
    """
    for name, fallback in DEFAULT_CONFIG.items():
        if config.get(name) is None:
            config[name] = fallback

    for name in config:
        if DEFAULT_CONFIG.get(name) is None:
            raise ValueError(f'Deprecated key in config dict: {name}!')

    return config


class BNStatisticsHook():
    '''
    Forward hook that records per-channel statistics of a module's input.

    After each forward pass of the hooked module, ``mean_var`` holds
    ``[mean, var]`` of the first positional input: the mean is taken over
    axes 0, 2 and 3 and the variance is the biased estimate over the same
    elements, giving one value per index of axis 1. Until the module has run
    once, ``mean_var`` is the integer 0. Turning the statistics into a loss is
    left to the caller.
    '''
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.mean_var=0

    def hook_fn(self, module, input, output):
        # Hook to compute DeepInversion's feature distribution regularization.
        features = input[0]
        channels = features.shape[1]
        # Bring axis 1 to the front and flatten the rest so the variance is per channel.
        flattened = features.permute(1, 0, 2, 3).contiguous().view([channels, -1])
        batch_mean = features.mean([0, 2, 3])
        batch_var = flattened.var(1, unbiased=False)

        # Only the latest forward pass is kept; no running average is maintained.
        self.mean_var = [batch_mean, batch_var]
        # Must return nothing so the module's output is left untouched.

    def close(self):
        # Detach the hook from the module.
        self.hook.remove()


class GradientReconstructor():
    """Instantiate a reconstruction algorithm."""

    def __init__(self, model, mean_std=(0.0, 1.0), config=DEFAULT_CONFIG, num_images=1, G=None, bn_prior=((0.0, 1.0)),dev="0"):
        """Initialize with algorithm setup.

        Args:
            model: Network whose parameter gradients are inverted. A forward hook is
                attached to every ``nn.BatchNorm2d`` in it when ``config['bn_stat'] > 0``.
            mean_std: ``(mean, std)`` pair kept on ``self.mean_std``; used to normalise
                generated images and to clamp reconstructions.
            config: Reconstruction options. Missing entries are filled from
                ``DEFAULT_CONFIG`` by ``_validate_config``; a dict supplied by the
                caller is updated in place.
            num_images: Number of images to reconstruct.
            G: Optional ready-made generator. When it is falsy and
                ``config['generative_model']`` is set, the generator is loaded through
                ``inversefed.porting`` instead.
            bn_prior: Stored unchanged on ``self.bn_prior``.
            dev: CUDA device index as a string; only used when CUDA is available.

        Raises:
            ValueError: If ``config`` holds a key unknown to ``DEFAULT_CONFIG``, or if
                ``config['generative_model']`` names a model that has no loader here
                and no ``G`` was supplied.
        """
        # The default argument is the module-level DEFAULT_CONFIG object itself. Work on
        # a private copy so later edits of ``self.config`` never leak into the defaults
        # shared by every other instance.
        if config is DEFAULT_CONFIG:
            config = dict(DEFAULT_CONFIG)
        self.config = _validate_config(config)
        self.model = model
        self.device = torch.device('cuda:'+dev) if torch.cuda.is_available() else torch.device('cpu')
        # Multi-GPU use is switched off: the DataParallel branches below only run when
        # this value exceeds 1.
        self.num_gpus = 1
        self.setup = dict(device=self.device, dtype=next(model.parameters()).dtype)

        self.mean_std = mean_std
        self.num_images = num_images
        self.useNG=False

        # BN statistics: one hook per BatchNorm2d layer, only when the term is enabled.
        self.bn_layers = []
        if self.config['bn_stat'] > 0:
            for module in model.modules():
                if isinstance(module, nn.BatchNorm2d):
                    self.bn_layers.append(BNStatisticsHook(module))
        self.bn_prior = bn_prior

        # Group regularizer state; populated during reconstruction.
        self.do_group_mean = False
        self.group_mean = None

        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
        self.iDLG = True

        if G:
            # A generator instance was handed in by the caller.
            print("Loading G...")
            if self.config['generative_model'] == 'stylegan2':
                self.G, self.G_mapping, self.G_synthesis = G, G.G_mapping, G.G_synthesis
                if self.num_gpus > 1:
                    self.G, self.G_mapping, self.G_synthesis = G, nn.DataParallel(self.G_mapping), nn.DataParallel(self.G_synthesis)
                self.G_mapping.to(self.device)
                self.G_synthesis.to(self.device)

                # Only the synthesis network is trainable; the mapping network stays frozen.
                self.G_mapping.requires_grad_(False)
                self.G_synthesis.requires_grad_(True)
                self.G_synthesis.random_noise()
            elif self.config['generative_model'].startswith('stylegan2-ada'):
                self.G, self.G_mapping, self.G_synthesis = G, G.mapping, G.synthesis
                if self.num_gpus > 1:
                    self.G, self.G_mapping, self.G_synthesis = G, nn.DataParallel(self.G_mapping), nn.DataParallel(self.G_synthesis)
                self.G_mapping.to(self.device)
                self.G_synthesis.to(self.device)

                self.G_mapping.requires_grad_(False)
                self.G_synthesis.requires_grad_(True)
            else:
                self.G = G
                if self.num_gpus > 1:
                    self.G = nn.DataParallel(self.G)
                self.G.to(self.device)
                self.G.requires_grad_(True)
            self.G.eval() # Disable stochastic dropout and using batch stat.
        elif self.config['generative_model']:
            # No generator instance given: load one by name through the porting helpers.
            model_name = self.config['generative_model']
            if model_name == 'stylegan2':
                self.G, self.G_mapping, self.G_synthesis = porting.load_decoder_stylegan2(self.config, self.device, dataset=self.config['gen_dataset'])
                self.G_mapping.to(self.device)
                self.G_synthesis.to(self.device)
                self.G_mapping.requires_grad_(False)
                self.G_synthesis.requires_grad_(True)
                self.G_mapping.eval()
                self.G_synthesis.eval()
            elif model_name in ['DCGAN']:
                G = porting.load_decoder_dcgan(self.config, self.device, dataset=self.config['gen_dataset'])
                G = G.requires_grad_(True)
                self.G = G
                self.G.eval()
            elif model_name in ['DCGAN-untrained']:
                G = porting.load_decoder_dcgan_untrained(self.config, self.device, dataset=self.config['gen_dataset'])
                G = G.requires_grad_(True)
                self.G = G
                self.G.eval()
            elif model_name in ['BigGAN']:
                G=porting.load_decoder_biggan(self.config, self.device, dataset=self.config['gen_dataset'])
                G = G.requires_grad_(True)
                self.G = G
                self.G.eval()
            else:
                # Previously this case fell through without ever assigning self.G, so the
                # failure only surfaced later as an AttributeError inside reconstruct().
                # The stylegan2-ada variants have no loader here and need ``G`` passed in.
                raise ValueError(
                    f"Unsupported generative_model {model_name!r}: no loader is available "
                    "for it, pass a generator instance via G instead."
                )
        else:
            self.G = None
        self.generative_model_name = self.config['generative_model']
        self.initial_z = None

    def set_initial_z(self, z):
        """Store ``z`` as the starting latent; ``init_dummy_z`` uses it instead of sampling when it is not ``None``."""
        self.initial_z = z

    def init_dummy_z(self, G, generative_model_name, num_images):
        if self.initial_z is not None:
            dummy_z = self.initial_z.clone().unsqueeze(0) \
                .expand(num_images, self.initial_z.shape[0], self.initial_z.shape[1]) \
                .to(self.device).requires_grad_(True)
        elif generative_model_name.startswith('stylegan2-ada'):
            dummy_z = torch.randn(num_images, 512).to(self.device)
            dummy_z = G.mapping(dummy_z, None, truncation_psi=0.5, truncation_cutoff=8)
            dummy_z = dummy_z.detach().requires_grad_(True)
        elif generative_model_name == 'stylegan2':
            dummy_z = torch.randn(num_images, 512).to(self.device)
            if self.config['gen_dataset'].startswith('I'):
                num_latent_layers = 16
            else:
                num_latent_layers = 18
            dummy_z = self.G_mapping(dummy_z).unsqueeze(1).expand(num_images, num_latent_layers, 512).detach().clone().to(self.device).requires_grad_(True)
            # dummy_noise = G.static_noise(trainable=True)
        elif generative_model_name in ['DCGAN', 'DCGAN-untrained']:
            dummy_z = torch.randn(num_images, 100, 1, 1).to(self.device).requires_grad_(True)
        elif generative_model_name in ['BigGAN']:
            dummy_z = torch.randn(num_images, 128, 1, 1).to(self.device).requires_grad_(True)
        return dummy_z


    def gen_dummy_data(self, G, generative_model_name, dummy_z,labels=0):
        running_device = dummy_z.device
        if generative_model_name.startswith('stylegan2-ada'):
            # @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
            dummy_data = G(dummy_z, noise_mode='random')
        elif generative_model_name.startswith('stylegan2'):
            dummy_data = G(dummy_z)
            if self.config['gen_dataset'].startswith('I'):
                kernel_size = 512 // self.image_size
            else:
                kernel_size = 1024 // self.image_size
            dummy_data = torch.nn.functional.avg_pool2d(dummy_data, kernel_size)

        elif generative_model_name in ['stylegan2-ada-z']:
            dummy_data = G(dummy_z, None, truncation_psi=0.5, truncation_cutoff=8)
        elif generative_model_name in ['DCGAN', 'DCGAN-untrained']:
            dummy_data = G(dummy_z)
        elif generative_model_name in ['BigGAN']:
            dummy_z=torch.reshape(dummy_z,(dummy_z.shape[0],dummy_z.shape[1]))
            #print(dummy_z.device,labels.device)
            dummy_data = G(dummy_z,labels,1)
        
        dm, ds = self.mean_std
        dummy_data = (dummy_data + 1) / 2
        dummy_data = (dummy_data - dm.to(running_device)) / ds.to(running_device)
        return dummy_data

    def count_trainable_params(self, G=None, z=None , x=None):
        n_z, n_G, n_x = 0,0,0
        if G:
            n_z = torch.numel(z) if z.requires_grad else 0
            print(f"z: {n_z}")
            n_G += sum(layer.numel() for layer in G.parameters() if layer.requires_grad)
            print(f"G: {n_G}")
        else:
            n_x = torch.numel(x) if x.requires_grad else 0
            print(f"x: {n_x}")
        self.n_trainable = n_z + n_G + n_x
    
    def reconstruct(self, input_data, labels, img_shape=(3, 32, 32), dryrun=False, eval=True, tol=None):
        """Reconstruct image from gradient."""
        start_time = time.time()
        if eval:
            self.model.eval()

        if torch.is_tensor(input_data[0]):
            input_data = [input_data]
        self.image_size = img_shape[1]
        
        stats = defaultdict(list)
        #print(img_shape)
        x = self._init_images(img_shape)
        
        scores = torch.zeros(self.config['restarts'])
        #print(labels)
        if labels is None:
            if self.num_images == 1 and self.iDLG:
                # iDLG trick:
                last_weight_min = torch.argmin(torch.sum(input_data[-2], dim=-1), dim=-1)
                labels = last_weight_min.detach().reshape((1,)).requires_grad_(False)
                self.reconstruct_label = False
            else:
                # DLG label recovery
                # However this also improves conditioning for some LBFGS cases
                self.reconstruct_label = True

                def loss_fn(pred, labels):
                    labels = torch.nn.functional.softmax(labels, dim=-1)
                    return torch.mean(torch.sum(- labels * torch.nn.functional.log_softmax(pred, dim=-1), 1))
                self.loss_fn = loss_fn
        else:
            #print( labels.shape[0],self.num_images)

            assert labels.shape[0] == self.num_images
            self.reconstruct_label = False
        x=x.to(next(self.model.parameters()).device)
        #print(x.device)
        labels=labels.to(next(self.model.parameters()).device)
        try:
            # labels = [None for _ in range(self.config['restarts'])]
            dummy_z = [None for _ in range(self.config['restarts'])]
            optimizer = [None for _ in range(self.config['restarts'])]
            scheduler = [None for _ in range(self.config['restarts'])]
            _x = [None for _ in range(self.config['restarts'])]
            max_iterations = self.config['max_iterations']

            
            if self.config['gias_iterations'] == 0:
                gias_iterations = 0
            else:
                gias_iterations = self.config['gias_iterations']

            for trial in range(self.config['restarts']):
                _x[trial] = x[trial]

                if self.G:
                    #print("work317")
                    if self.useNG==False:
                        dummy_z[trial] = self.init_dummy_z(self.G, self.generative_model_name, _x[trial].shape[0])
                        #dummy_z[trial]=dummy_z[trial].to(next(self.model.parameters()).device)
                        if self.config['optim'] == 'adam':
                            optimizer[trial] = torch.optim.Adam([dummy_z[trial]], lr=self.config['lr'])
                        elif self.config['optim'] == 'sgd':  # actually gd
                            optimizer[trial] = torch.optim.SGD([dummy_z[trial]], lr=0.01, momentum=0.9, nesterov=True)
                        elif self.config['optim'] == 'LBFGS':
                            optimizer[trial] = torch.optim.LBFGS([dummy_z[trial]])
                        else:
                            raise ValueError()
                    else:
                        parametrization = ng.p.Array(init=np.random.rand(128))
                        #print(parametrization)
                        optimizer = ng.optimizers.registry["CMA"](parametrization=parametrization, budget=500)
                    
                else:
                    _x[trial].requires_grad = True
                    if self.config['optim'] == 'adam':
                        optimizer[trial] = torch.optim.Adam([_x[trial]], lr=self.config['lr'])
                    elif self.config['optim'] == 'sgd':  # actually gd
                        optimizer[trial] = torch.optim.SGD([_x[trial]], lr=0.01, momentum=0.9, nesterov=True)
                    elif self.config['optim'] == 'LBFGS':
                        optimizer[trial] = torch.optim.LBFGS([_x[trial]])
                    else:
                        raise ValueError()

                if self.config['lr_decay']:
                    scheduler[trial] = torch.optim.lr_scheduler.MultiStepLR(optimizer[trial],
                                                                        milestones=[max_iterations // 2.667, max_iterations // 1.6,

                                                                                    max_iterations // 1.142], gamma=0.1)   # 3/8 5/8 7/8
            dm, ds = self.mean_std
            
            if self.G:
                print("Start latent space search")
                #self.count_trainable_params(G=self.G, z=dummy_z[0])
            else:
                print("Start original space search")
                self.count_trainable_params(x=_x[0])
            #print(f"Total number of trainable parameters: {self.n_trainable}")
            
            if self.useNG==False:
                for iteration in range(max_iterations):
                    for trial in range(self.config['restarts']):
                        losses = [0,0,0,0,0,0]
                        # x_trial = _x[trial]
                        # x_trial.requires_grad = True
                        
                        #Group Regularizer
                        #print(self.config['group_lazy'])
                        if self.config['group_lazy'] > 0:

                            if trial == 0 and iteration + 1 == construct_group_mean_at:
                                self.do_group_mean = True
                                #self.group_mean = group_r(_x)
                                self.group_mean = torch.mean(torch.stack(_x), dim=0).detach().clone()
                                #xm= torch.mean(torch.stack(_x), dim=0).detach().clone()
                                #self.group_mean = group_r(_x,self.mean_std,xm)
                                

                            if self.do_group_mean and trial == 0 and (iteration + 1) % construct_gm_every == 0:
                                print("construct group mean")
                                #self.group_mean = group_r(_x)
                                if iteration >= construct_gm_switch:
                                    xm= torch.mean(torch.stack(_x), dim=0).detach().clone()
                                    self.group_mean = group_r(_x,self.mean_std,xm)
                                else:
                                    self.group_mean = torch.mean(torch.stack(_x), dim=0).detach().clone()

                        if self.G:
                            if self.generative_model_name in ['stylegan2','stylegan2-ada','stylegan2-ada-untrained']:
                                _x[trial] = self.gen_dummy_data(self.G_synthesis, self.generative_model_name, dummy_z[trial])

                            elif self.generative_model_name in ['BigGAN']: 
                                #print(labels)
                                if self.useNG==False:
                                    c = torch.nn.functional.one_hot(labels, num_classes=1000).to(labels.device)
                                    _x[trial] = self.gen_dummy_data(self.G, self.generative_model_name, dummy_z[trial],labels=c.float())
                                else:
                                    recommendation = self.optimizer.provide_recommendation()
                                    z_res = torch.from_numpy(recommendation.value).unsqueeze(0).to(self.device)
                                    c = torch.nn.functional.one_hot(labels, num_classes=1000).to(labels.device)
                                    _x[trial] = self.gen_dummy_data(self.G, self.generative_model_name, dummy_z[trial],labels=c.float())
                            else:
                                
                                _x[trial] = self.gen_dummy_data(self.G, self.generative_model_name, dummy_z[trial])
                            self.dummy_z = dummy_z[trial]
                        else:
                            self.dummy_z = None
                        # print(x_trial)
                        #print(_x[trial].shape)
                        closure = self._gradient_closure(optimizer[trial], _x[trial], input_data, labels, losses, float(iteration)/max_iterations)
                        rec_loss = optimizer[trial].step(closure)
                        if self.config['lr_decay']:
                            scheduler[trial].step()

                        with torch.no_grad():
                            # Project into image space
                            _x[trial].data = torch.max(torch.min(_x[trial], (1 - dm) / ds), -dm / ds)

                            if (iteration + 1 == max_iterations) or iteration % save_interval == 0:
                                print(f'It: {iteration}. Rec. loss: {rec_loss.item():2.4E} | tv: {losses[0]:7.4f} | bn: {losses[1]:7.4f} | dist: {losses[2]:7.4f} | gr: {losses[3]:7.4f}| patch: {losses[4]:7.4f}| april: {losses[5]:7.4f}')
                                if self.config['z_norm'] > 0:
                                    print(torch.norm(dummy_z[trial], 2).item())
                #print(_x)
                #print(dummy_z)
                        if dryrun:
                            break
            else:
                #
                #from reconstructor import NGReconstructor
                res_trials = [None]*self.config['restarts']
                loss_trials = [None]*self.config['restarts']
                _x = [None for _ in range(self.config['restarts'])]
                dummy_z = [None for _ in range(self.config['restarts'])]
                for trial in range(self.config['restarts']):
                    print('Processing trial {}/{}.'.format(trial+1, self.config['restarts']))
                    ng_rec = NGReconstructor_B(fl_model=self.model, generator=self.G, loss_fn=self.loss_fn,
                                         num_classes=1000, search_dim=(1,128), strategy="CMA", budget=1000, use_tanh=False, defense_setting=None)
                    #print(input_data)
                    z_res, x_res, img_res, loss_res = ng_rec.reconstruct(input_data[trial],labels)
                    res_trials[trial] = {'z':z_res, 'x':x_res, 'img':img_res}
                    loss_trials[trial] = loss_res
                    _x[trial]=x_res.float()
                    dummy_z[trial]=z_res.view(1, 128, 1, 1).float()
                best_t = np.argmin(loss_trials)
                z_res, x_res, img_res = res_trials[best_t]['z'], res_trials[best_t]['x'], res_trials[best_t]['img']
                
                stats['opt']=best_t
 
        except KeyboardInterrupt:
            print(f'Recovery interrupted manually in iteration {iteration}!')
            pass
        try:

            if self.config['giml']:
                print("Start giml")
                
                
                
                print('Choosing optimal z...')
                
                for trial in range(self.config['restarts']):
                    x[trial] = _x[trial].detach()
                    scores[trial] = self._score_trial(x[trial], input_data, labels)
                    if tol is not None and scores[trial] <= tol:
                        break
                    if dryrun:
                        break
                scores = scores[torch.isfinite(scores)]  # guard against NaN/-Inf scores?
                optimal_index = torch.argmin(scores)
                print(f'Optimal result score: {scores[optimal_index]:2.4f}')
                optimal_z = dummy_z[optimal_index].detach().clone()
                
                self.dummy_z = optimal_z.detach().clone().cpu()
                
                if self.generative_model_name in ['stylegan2','stylegan2-ada','stylegan2-ada-untrained']:
                    G_list = [deepcopy(self.G_synthesis) for _ in range(self.config['restarts'])]
                    for trial in range(self.config['restarts']):
                        G_list[trial].requires_grad_(True)
                else:
                    G_list = [deepcopy(self.G) for _ in range(self.config['restarts'])]

                for trial in range(self.config['restarts']):
                    if self.config['optim'] == 'adam':
                        optimizer[trial] = torch.optim.Adam(G_list[trial].parameters(), lr=self.config['gias_lr'])
                    else:
                        raise ValueError()
        
                    if self.config['lr_decay']:
                        scheduler[trial] = torch.optim.lr_scheduler.MultiStepLR(optimizer[trial],
                                                                        milestones=[gias_iterations // 2.667, gias_iterations // 1.6,

                                                                                    gias_iterations // 1.142], gamma=0.1)   # 3/8 5/8 7/8

                for iteration in range(gias_iterations):
                    for trial in range(self.config['restarts']):
                        losses = [0,0,0,0]
                        # x_trial = _x[trial]
                        # x_trial.requires_grad = True
                        
                        #Group Regularizer
                        if self.config['restarts'] > 1 and trial == 0 and iteration + 1 == construct_group_mean_at and self.config['group_lazy'] > 0:
                            self.do_group_mean = True
                            self.group_mean = torch.mean(torch.stack(_x), dim=0).detach().clone()

                        if self.do_group_mean and trial == 0 and (iteration + 1) % construct_gm_every == 0:
                            print("construct group mean")
                            self.group_mean = torch.mean(torch.stack(_x), dim=0).detach().clone()

                        _x[trial] = self.gen_dummy_data(G_list[trial], self.generative_model_name, optimal_z)
                        # print(x_trial)
                        closure = self._gradient_closure(optimizer[trial], _x[trial], input_data, labels, losses, float(iteration)/max_iterations)
                        rec_loss = optimizer[trial].step(closure)
                        if self.config['lr_decay']:
                            scheduler[trial].step()

                        with torch.no_grad():
                            # Project into image space
                            _x[trial].data = torch.max(torch.min(_x[trial], (1 - dm) / ds), -dm / ds)

                            if (iteration + 1 == gias_iterations) or iteration % save_interval == 0:
                                print(f'It: {iteration}. Rec. loss: {rec_loss.item():2.4E} | tv: {losses[0]:7.4f} | bn: {losses[1]:7.4f} | l2: {losses[2]:7.4f} | gr: {losses[3]:7.4f}')

                        if dryrun:
                            break

            elif self.config['gias'] and gias_iterations !=0:
                print('Choosing optimal z...')
                for trial in range(self.config['restarts']):
                    
                    x[trial] = _x[trial].detach()
                    scores[trial] = self._score_trial(x[trial], input_data, labels)
                    if tol is not None and scores[trial] <= tol:
                        break
                    if dryrun:
                        break
                scores = scores[torch.isfinite(scores)]  # guard against NaN/-Inf scores?
                optimal_index = torch.argmin(scores)
                print(f'Optimal result score: {scores[optimal_index]:2.4f}')
                optimal_z = dummy_z[optimal_index].detach().clone()
                
                self.dummy_z = optimal_z.detach().clone().cpu()
                self.dummy_z=self.dummy_z.to(self.device)
                self.dummy_zs = [None for k in range(self.num_images)]
                # WIP: multiple GPUs                   
                for k in range(self.num_images):
                    self.dummy_zs[k] = torch.unsqueeze(self.dummy_z[k], 0)

                G_list2d = [None for _ in range(self.config['restarts'])]
                # optimizer2d = [None for _ in range(self.config['restarts'])]
                # scheduler2d = [None for _ in range(self.config['restarts'])]
                optimizer = [None for _ in range(self.config['restarts'])]
                for trial in range(self.config['restarts']):
                     if self.generative_model_name in ['stylegan2']:
                        G_list2d[trial] = [deepcopy(self.G_synthesis) for _ in range(self.num_images)]
                     elif self.generative_model_name in ['DCGAN', 'DCGAN-untrained']:
                            G_list2d[trial] = [deepcopy(self.G) for _ in range(self.num_images)]
                     else:
                        from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
                                       save_as_images, display_in_terminal, convert_to_images)
                        if self.config['dataset'].endswith("128"):
                            G_list2d[trial] = [(BigGAN.from_pretrained('biggan-deep-128').to(self.device)).requires_grad_(True) for _ in range(self.num_images)]
                        elif self.config['dataset'].endswith("256"):
                            G_list2d[trial] = [(BigGAN.from_pretrained('biggan-deep-256').to(self.device)).requires_grad_(True) for _ in range(self.num_images)]
                        elif self.config['dataset'].endswith("512"):
                            G_list2d[trial] = [(BigGAN.from_pretrained('biggan-deep-512').to(self.device)).requires_grad_(True) for _ in range(self.num_images)]

                
                if self.num_gpus > 1:
                    print(f"Spliting generators into {self.num_gpus} GPUs...")
                    for trial in range(self.config['restarts']):
                        for k in range(self.num_images):
                            G_list2d[trial][k] = G_list2d[trial][k].to(f'cuda:{k%self.num_gpus}')
                            G_list2d[trial][k].requires_grad_(True)
                            self.dummy_zs[k] = self.dummy_zs[k].to(f'cuda:{k%self.num_gpus}')
                            self.dummy_zs[k].requires_grad_(False)
                else:
                    for trial in range(self.config['restarts']):
                        for k in range(self.num_images):
                            G_list2d[trial][k] = G_list2d[trial][k].to(self.device)
                            G_list2d[trial][k].requires_grad_(True)
                            self.dummy_zs[k] = self.dummy_zs[k].to(self.device)
                            self.dummy_zs[k].requires_grad_(False)

                for trial in range(self.config['restarts']):
                    if self.config['optim'] == 'adam':
                        optimizer[trial] = torch.optim.Adam([{'params': G_list2d[trial][k].parameters()} for k in range(self.num_images)], lr=self.config['gias_lr'])
                    else:
                        raise ValueError()
        
                    if self.config['lr_decay']:
                        scheduler[trial] = torch.optim.lr_scheduler.MultiStepLR(optimizer[trial],
                                            milestones=[gias_iterations // 2.667, gias_iterations // 1.6,
                                            gias_iterations // 1.142], gamma=0.1)   # 3/8 5/8 7/8

                
                

                self.count_trainable_params(G=self.G, z=self.dummy_zs[0])
                print(f"Total number of trainable parameters: {self.n_trainable}")

                print("Start Parameter search")

                for iteration in range(gias_iterations):
                    for trial in range(self.config['restarts']):
                        losses = [0,0,0,0]
                        # x_trial = _x[trial]
                        # x_trial.requires_grad = True
                        
                        #Group Regularizer
                        if self.config['restarts'] > 1 and trial == 0 and iteration + 1 == construct_group_mean_at and self.config['group_lazy'] > 0:
                            self.do_group_mean = True
                            self.group_mean = torch.mean(torch.stack(_x), dim=0).detach().clone()

                        if self.do_group_mean and trial == 0 and (iteration + 1) % construct_gm_every == 0:
                            print("construct group mean")
                            self.group_mean = torch.mean(torch.stack(_x), dim=0).detach().clone()
                        
                        # Load G to GPU
                        # for k in range(self.num_images):
                            # G_list2d[trial][k].to(**self.setup).requires_grad_(True)
                        c = torch.nn.functional.one_hot(labels, num_classes=1000).to(labels.device)
                        _x_trial = [self.gen_dummy_data(G_list2d[trial][k], self.generative_model_name, self.dummy_zs[k],labels=c.float()).to('cpu') for k in range(self.num_images)]
                        _x[trial] = torch.stack(_x_trial).squeeze(1).to(self.device)

                        # print(x_trial)
                        closure = self._gradient_closure(optimizer[trial], _x[trial], input_data, labels, losses, float(iteration)/max_iterations)
                        rec_loss = optimizer[trial].step(closure)
                        if self.config['lr_decay']:
                            scheduler[trial].step()

                        with torch.no_grad():
                            # Project into image space
                            _x[trial].data = torch.max(torch.min(_x[trial], (1 - dm) / ds), -dm / ds)

                            if (iteration + 1 == gias_iterations) or iteration % save_interval == 0:
                                print(f'It: {iteration}. Rec. loss: {rec_loss.item():2.4E} | tv: {losses[0]:7.4f} | bn: {losses[1]:7.4f} | l2: {losses[2]:7.4f} | gr: {losses[3]:7.4f}')

                        # Unload G to CPU
                        # for k in range(self.num_images):
                        #     G_list2d[trial][k].cpu()

                        if dryrun:
                            break
        except KeyboardInterrupt:
            print(f'Recovery interrupted manually in iteration {iteration}!')
            pass

                    
        for trial in range(self.config['restarts']):
            x[trial] = _x[trial].detach()
            scores[trial] = self._score_trial(x[trial], input_data, labels)
            if tol is not None and scores[trial] <= tol:
                break
            if dryrun:
                break
        # Choose optimal result:
        print('Choosing optimal result ...')
        scores = scores[torch.isfinite(scores)]  # guard against NaN/-Inf scores?
        optimal_index = torch.argmin(scores)
        print(f'Optimal result score: {scores[optimal_index]:2.4f}')
        stats['opt'] = scores[optimal_index].item()
        x_optimal = x[optimal_index]
        if self.G and self.config['giml']:
            self.G = G_list[optimal_index]
        elif self.G and self.config['gias'] and self.config['gias_iterations']!=0:
            self.G = G_list2d[optimal_index]

        print(f'Total time: {time.time()-start_time}.')
        #上采样
        if self.config.has_key("ROG"):

            #x_optimal=F.interpolate(x_optimal, scale_factor=2, mode='bicubic')
            
            from .unet import UNet
            p=UNet(in_channels=3, out_channels=3).to(x_optimal.device)
            p.eval()
            #print(x_optimal.device)
            if str(x_optimal.device)=="cuda:1" or str(x_optimal.device)=="cuda:3":
                p.load_state_dict(torch.load('unet_cifar10_2_9.pth'))
            elif str(x_optimal.device)=="cuda:2" or str(x_optimal.device)=="cuda:4":
                p.load_state_dict(torch.load('unet_cifar100_2_9.pth'))
                #print("yes")
            x_optimal=p(x_optimal)
            ''''''
        return x_optimal.detach(), stats


    def reconstruct_theta(self, input_gradients, labels, models, candidate_images, img_shape=(3, 32, 32), dryrun=False, eval=True, tol=None):
        """Reconstruct one image per label from a sequence of gradients.

        A single image tensor is kept per label value in ``candidate_images``.
        Missing entries are created with ``torch.randn`` and all referenced
        entries are optimised in place, so the caller's dict is modified.

        Args:
            input_gradients: One observed gradient (sequence of tensors) per step.
            labels: One label tensor per step, each holding ``self.num_images`` labels.
            models: One model per step; ``models[t]`` is evaluated on the images
                selected by ``labels[t]``.
            candidate_images: Dict mapping a label value to its image tensor of
                shape ``(1, *img_shape)``.
            img_shape: Shape of a single image, used for newly created candidates.
            dryrun: Stop after the first optimisation step.
            eval: Switch ``self.model`` to eval mode before starting.
            tol: Not used by this method.

        Returns:
            ``(images, stats)`` where ``images`` is the detached batch for the
            last step and ``stats['opt']`` is its score from ``_score_trial``
            (stored as-is, even when it is not finite).
        """
        start_time = time.time()
        if eval:
            self.model.eval()

        stats = defaultdict(list)
        self.reconstruct_label = False

        assert self.config['restarts'] == 1
        max_iterations = self.config['max_iterations']
        num_seq = len(models)
        assert num_seq == len(input_gradients)
        assert num_seq == len(labels)
        for step_labels in labels:
            assert step_labels.shape[0] == self.num_images

        # Bound before the try block so the interrupt handler and the final
        # scoring below never touch an unbound name.
        iteration = 0
        batch_images = [[] for _ in range(num_seq)]

        try:
            skip_t = set()
            current_labels = {label.item() for label in labels[-1]}
            # Keyed by id() so every distinct tensor is handed to the optimizer
            # exactly once, in first-seen order.
            optimize_target = {}

            for t in range(num_seq):
                label_values = [label_.item() for label_ in labels[t]]
                for label in label_values:
                    if label not in candidate_images:
                        candidate_images[label] = torch.randn((1, *img_shape), **self.setup).requires_grad_(True)
                    image = candidate_images[label]
                    batch_images[t].append(image)
                    optimize_target.setdefault(id(image), image)
                # A step that shares no label with the final step is left out
                # of the loss.
                if not any(label in current_labels for label in label_values):
                    skip_t.add(t)

            optimizer = torch.optim.Adam(list(optimize_target.values()), lr=self.config['lr'])
            if self.config['lr_decay']:
                scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                                    milestones=[max_iterations // 2.667, max_iterations // 1.6,
                                                                                max_iterations // 1.142], gamma=0.1)   # 3/8 5/8 7/8

            for iteration in range(max_iterations):
                losses = [0, 0, 0, 0]
                batch_input = [torch.cat(batch_images[t], dim=0) for t in range(num_seq)]

                def closure():
                    total_loss = 0
                    optimizer.zero_grad()
                    for model in models:
                        model.zero_grad()
                    for t in range(num_seq):
                        if t in skip_t:
                            continue
                        loss = self.loss_fn(models[t](batch_input[t]), labels[t])
                        gradient = torch.autograd.grad(loss, models[t].parameters(), create_graph=True)
                        rec_loss = reconstruction_costs([gradient], input_gradients[t],
                                                        cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                                        weights=self.config['weights'])

                        if self.config['total_variation'] > 0:
                            tv_loss = TV(batch_input[t])
                            rec_loss += self.config['total_variation'] * tv_loss
                            losses[0] = tv_loss
                        total_loss += rec_loss
                    total_loss.backward()
                    return total_loss

                rec_loss = optimizer.step(closure)
                if self.config['lr_decay']:
                    scheduler.step()

                with torch.no_grad():
                    if (iteration + 1 == max_iterations) or iteration % save_interval == 0:
                        print(f'It: {iteration}. Rec. loss: {rec_loss.item():2.4E} | tv: {losses[0]:7.4f} | bn: {losses[1]:7.4f} | l2: {losses[2]:7.4f} | gr: {losses[3]:7.4f}')

                if dryrun:
                    break

        except KeyboardInterrupt:
            print(f'Recovery interrupted manually in iteration {iteration}!')

        # Rebuild the final-step batch from the (possibly partially) optimised
        # candidates instead of relying on a loop-local variable.
        final_input = torch.cat(batch_images[-1], dim=0)

        score = self._score_trial(final_input, [input_gradients[-1]], labels[-1])
        stats['opt'] = score.item()

        print(f'Total time: {time.time()-start_time}.')
        return final_input.detach(), stats

    def _init_images(self, img_shape):
        #下采样
        if self.config['init'] == 'randn':
            if self.config.has_key("ROG"):
                return torch.randn((self.config['restarts'], self.num_images, img_shape[0],int(img_shape[1]/2),int(img_shape[2]/2)), **self.setup)
            return torch.randn((self.config['restarts'], self.num_images, img_shape[0],int(img_shape[1]),int(img_shape[2])), **self.setup)
        elif self.config['init'] == 'rand':
            return (torch.rand((self.config['restarts'], self.num_images, *img_shape), **self.setup) - 0.5) * 2
        elif self.config['init'] == 'zeros':
            return torch.zeros((self.config['restarts'], self.num_images, *img_shape), **self.setup)
        else:
            raise ValueError()


    def _gradient_closure(self, optimizer, x_trial, input_gradient, label, losses,stage):
        """Build the optimizer closure for one trial.

        The closure splits ``x_trial``/``label`` into ``len(input_gradient)``
        equally sized batches, matches the model gradient of each batch against
        the corresponding observed gradient and adds the configured
        regularisers, then backpropagates the summed loss.

        Args:
            optimizer: Optimizer whose gradients are cleared on every call.
            x_trial: Candidate images being optimised.
            input_gradient: Sequence of observed gradients, one per batch.
            label: Labels for ``x_trial``.
            losses: Mutable list that receives the latest regulariser values:
                0 total variation, 1 BN statistics, 2 image norm, 3 group mean,
                4 patch prior, 5 april prior. It is grown in place when it is
                too short for the slot being written.
            stage: Progress value compared against 0.5; selects the
                gradient-loss scale when ``config['r_patch'] > 0``.

        Returns:
            A zero-argument callable returning the total loss.

        Note:
            When ``config['r_patch'] > 0`` the closure overwrites
            ``self.config['bn_stat']`` with 0.
        """

        def record(slot, value):
            # Callers hand in a 4-element list; slots 4 and 5 must not overflow it.
            if len(losses) <= slot:
                losses.extend([0] * (slot + 1 - len(losses)))
            losses[slot] = value

        def closure():
            num_images = label.shape[0]
            num_gradients = len(input_gradient)
            batch_size = num_images // num_gradients
            num_batch = num_images // batch_size
            # `dict.has_key` does not exist in Python 3; use a membership test.
            upsample = 'ROG' in self.config
            image_area = imsize_dict[self.config['dataset']] ** 2

            total_loss = 0
            optimizer.zero_grad()
            self.model.zero_grad()
            for i in range(num_batch):
                start_idx = i * batch_size
                end_idx = start_idx + batch_size
                batch_input = x_trial[start_idx:end_idx]
                batch_label = label[start_idx:end_idx]
                if upsample:
                    # Candidates are kept at half resolution; upsample for the model.
                    batch_input = F.interpolate(batch_input, scale_factor=2, mode='bicubic')

                loss = self.loss_fn(self.model(batch_input), batch_label)
                gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)

                rec_loss = reconstruction_costs([gradient], input_gradient[i],
                                                cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                                weights=self.config['weights'])

                if self.config['r_patch'] > 0:
                    # Two-stage schedule on `stage`; the image-prior weight is 0
                    # in both stages at present.
                    alpha_grad = 4e-3 if stage <= 0.5 else 2e-3
                    alpha_image = 0
                    rec_loss *= alpha_grad
                    self.config['bn_stat'] = alpha_image

                if self.config['total_variation'] > 0:
                    tv_loss = TV(x_trial)
                    rec_loss += self.config['total_variation'] * tv_loss
                    record(0, tv_loss)

                if self.config['bn_stat'] > 0:
                    if self.config['r_patch'] > 0:
                        # NOTE: currently unreachable, because the r_patch block
                        # above has just set config['bn_stat'] to 0.
                        from torchvision.models import resnet50, ResNet50_Weights
                        pre_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
                        pre_model.cuda()
                        priors = []
                        for name, module in pre_model.named_modules():
                            if isinstance(module, torch.nn.BatchNorm2d):
                                bn_params_dict = module.state_dict()
                                priors.append([bn_params_dict['running_mean'], bn_params_dict['running_var']])
                        hooks = [BNStatisticsHook(module) for module in pre_model.modules()
                                 if isinstance(module, nn.BatchNorm2d)]
                        x = x_trial.detach().clone()
                        if upsample:
                            x = F.interpolate(x, scale_factor=2, mode='bicubic')
                        pre_model(x)
                    else:
                        hooks, priors = self.bn_layers, self.bn_prior

                    bn_loss = 0
                    first_bn_multiplier = 10.
                    rescale = [first_bn_multiplier] + [1. for _ in range(len(hooks) - 1)]
                    # A dedicated index keeps the batch index `i` intact for the
                    # terms below.
                    for layer_idx, (my, pr) in enumerate(zip(hooks, priors)):
                        bn_loss += rescale[layer_idx] * (torch.norm(pr[0] - my.mean_var[0], 2) + torch.norm(pr[1] - my.mean_var[1], 2))
                    rec_loss += self.config['bn_stat'] * bn_loss
                    record(1, bn_loss)

                if self.config['image_norm'] > 0:
                    norm_loss = torch.norm(x_trial, 2) / image_area
                    rec_loss += self.config['image_norm'] * norm_loss
                    record(2, norm_loss)

                if self.do_group_mean and self.config['group_lazy'] > 0:
                    group_loss = torch.norm(x_trial - self.group_mean, 2) / image_area
                    rec_loss += self.config['group_lazy'] * group_loss
                    record(3, group_loss)

                if self.config['r_patch'] > 0:
                    patch_loss = RP(x_trial)
                    rec_loss += self.config['r_patch'] * patch_loss
                    record(4, patch_loss)

                if self.config['r_april'] > 0:
                    # Position(s) of the positional embedding in the state dict.
                    idx = [pos for pos, name in enumerate(self.model.state_dict().keys())
                           if name == 'module.pos_embedding']
                    if len(idx) == 1:
                        idx = idx[0]
                    april_loss = AP([gradient], input_gradient[i], idx)
                    record(5, april_loss)
                    rec_loss += self.config['r_april'] * april_loss

                if self.config['z_norm'] > 0:
                    if self.dummy_z is not None:
                        z_loss = torch.norm(self.dummy_z, 2)
                        rec_loss += self.config['z_norm'] * z_loss

                total_loss += rec_loss
            total_loss.backward()
            return total_loss
        return closure

    def _score_trial(self, x_trial, input_gradient, label):
        """Return the gradient-matching cost of ``x_trial`` without regularisers.

        ``x_trial`` and ``label`` are split into ``len(input_gradient)`` equally
        sized batches; the reconstruction cost of each batch against its
        observed gradient is summed.

        Args:
            x_trial: Candidate images; its ``.grad`` is reset to ``None``.
            input_gradient: Sequence of observed gradients, one per batch.
            label: Labels for ``x_trial``.

        Returns:
            Sum of the per-batch reconstruction costs.
        """
        num_images = label.shape[0]
        num_gradients = len(input_gradient)
        batch_size = num_images // num_gradients
        num_batch = num_images // batch_size
        # `dict.has_key` does not exist in Python 3; use a membership test.
        upsample = 'ROG' in self.config

        total_loss = 0
        for i in range(num_batch):
            self.model.zero_grad()
            x_trial.grad = None

            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            batch_input = x_trial[start_idx:end_idx]
            batch_label = label[start_idx:end_idx]
            if upsample:
                # Candidates are kept at half resolution; upsample for the model.
                batch_input = F.interpolate(batch_input, scale_factor=2, mode='bicubic')
            loss = self.loss_fn(self.model(batch_input), batch_label)
            gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False)
            total_loss += reconstruction_costs([gradient], input_gradient[i],
                                               cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                               weights=self.config['weights'])
        return total_loss


class FedAvgReconstructor(GradientReconstructor):
    """Reconstruct an image from weights after n gradient descent steps."""

    def __init__(self, model, mean_std=(0.0, 1.0), local_steps=2, local_lr=1e-4,
                 config=DEFAULT_CONFIG, num_images=1, use_updates=True, batch_size=0, bn_prior=((0.0, 1.0)),
                 G=None,known=False):
        """Initialize with model, (mean, std) and config.

        Args:
            model: Model whose update is being inverted.
            mean_std: Forwarded to the base class.
            local_steps: Number of simulated local steps (used when ``known``).
            local_lr: Learning rate of the simulated local steps.
            config: Reconstruction config, forwarded to the base class.
            num_images: Forwarded to the base class.
            use_updates: Stored on the instance; not read in this class.
            batch_size: Local batch size (used when ``known``).
            bn_prior: Sequence of ``(mean, var)`` pairs compared against the BN hooks.
                NOTE(review): the default ``((0.0, 1.0))`` is a flat pair, not a
                sequence of pairs - confirm callers always pass their own.
            G: Generator, forwarded to the base class.
            known: Whether ``local_steps``/``batch_size`` are known; otherwise
                ``loss_steps`` simulates a single full-batch step.
        """
        super().__init__(model, mean_std, config, num_images, G=G)
        self.local_steps = local_steps
        self.local_lr = local_lr
        self.use_updates = use_updates
        self.batch_size = batch_size
        self.bn_prior = bn_prior
        self.known = known

    def _gradient_closure(self, optimizer, x_trial, input_gradient, label, losses,stage):
        """Build the optimizer closure matching a simulated local update.

        The closure simulates the local training with ``loss_steps``, compares
        the resulting parameter update against ``input_gradient[0]`` and adds
        the configured regularisers. ``losses[0..3]`` receive the latest total
        variation, BN, image-norm and group-mean terms. ``stage`` is unused.
        """

        def closure():
            num_images = label.shape[0]

            optimizer.zero_grad()
            self.model.zero_grad()

            gradient, bn = loss_steps(self.model, x_trial, label, loss_fn=self.loss_fn,
                                      local_steps=self.local_steps, lr=self.local_lr,
                                      num_images=num_images, batch_size=self.batch_size, known=self.known)
            self.bn = bn

            rec_loss = reconstruction_costs([gradient], input_gradient[0],
                                            cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                            weights=self.config['weights'])

            if self.config['total_variation'] > 0:
                tv_loss = TV(x_trial)
                rec_loss += self.config['total_variation'] * tv_loss
                losses[0] = tv_loss
            if self.config['bn_stat'] > 0:
                bn_loss = 0
                first_bn_multiplier = 10.
                # Sized from the BN statistics returned by loss_steps;
                # presumably the same count as self.bn_layers - TODO confirm.
                rescale = [first_bn_multiplier] + [1. for _ in range(len(self.bn) - 1)]
                for layer_idx, (my, pr) in enumerate(zip(self.bn_layers, self.bn_prior)):
                    bn_loss += rescale[layer_idx] * (torch.norm(pr[0] - my.mean_var[0], 2) + torch.norm(pr[1] - my.mean_var[1], 2))
                rec_loss += self.config['bn_stat'] * bn_loss
                losses[1] = bn_loss
            if self.config['image_norm'] > 0:
                norm_loss = torch.norm(x_trial, 2) / (imsize_dict[self.config['dataset']] ** 2)
                rec_loss += self.config['image_norm'] * norm_loss
                losses[2] = norm_loss
            if self.do_group_mean and self.config['group_lazy'] > 0:
                group_loss = torch.norm(x_trial - self.group_mean, 2) / (imsize_dict[self.config['dataset']] ** 2)
                rec_loss += self.config['group_lazy'] * group_loss
                losses[3] = group_loss
            if self.config['z_norm'] > 0:
                if self.dummy_z is not None:
                    z_loss = torch.norm(self.dummy_z, 2)
                    # NOTE(review): fixed weight 1e-3 rather than
                    # config['z_norm'] - confirm this is intended.
                    rec_loss += 1e-3 * z_loss

            rec_loss.backward()
            return rec_loss
        return closure

    def _score_trial(self, x_trial, input_gradient, label):
        """Return the single-step gradient-matching cost of ``x_trial``.

        The model gradient on the whole batch is compared against
        ``input_gradient[0]``; no regularisers are applied.
        """
        self.model.zero_grad()
        x_trial.grad = None
        # `dict.has_key` does not exist in Python 3; use a membership test.
        if 'ROG' in self.config:
            # Candidates are kept at half resolution; upsample for the model.
            x_trial = F.interpolate(x_trial, scale_factor=2, mode='bicubic')

        loss = self.loss_fn(self.model(x_trial), label)
        gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False)

        return reconstruction_costs([gradient], input_gradient[0],
                                    cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                    weights=self.config['weights'])
    
def loss_steps(ori_model, inputs, labels, loss_fn=torch.nn.CrossEntropyLoss(), lr=1e-4, num_images=0,local_steps=4, batch_size=0,known=False):
    """Take a few gradient descent steps to fit the model to the given input.

    The steps are applied functionally (``param - lr * grad`` with
    ``create_graph=True``), so the returned update stays differentiable with
    respect to ``inputs``. ``ori_model`` is wrapped in ``MetaMonkey`` and put
    into eval mode.

    Args:
        ori_model: Model to simulate training on.
        inputs: Candidate images; the first ``num_images`` are consumed in
            consecutive batches.
        labels: Labels aligned with ``inputs``.
        loss_fn: Loss applied to ``(output, label)``; its result is summed.
        lr: Step size of the simulated updates.
        num_images: Number of images to consume per pass.
        local_steps: Number of passes over the images (only when ``known``).
        batch_size: Batch size (only when ``known``); a non-positive value
            falls back to ``num_images``.
        known: When false, a single full-batch step is simulated regardless of
            ``local_steps`` and ``batch_size``.

    Returns:
        ``(update, bn)``: ``update`` is a list with ``initial - final`` for
        every parameter, ``bn`` holds ``[running_mean, running_var]`` for every
        ``BatchNorm2d`` module of the wrapped model.

    Raises:
        ValueError: If no positive batch size can be determined.
    """
    model = MetaMonkey(ori_model)
    # Snapshot of the starting parameters; `model.parameters` is rebound below.
    initial_params = list(model.parameters.items())
    model.net.eval()

    if not known:
        local_steps = 1
        batch_size = num_images
    elif batch_size <= 0:
        # An unset batch size means "use everything in one batch".
        batch_size = num_images
    if batch_size <= 0:
        raise ValueError('loss_steps needs a positive batch size or num_images.')

    num_batches = num_images // batch_size
    for _ in range(local_steps):
        for j in range(num_batches):
            start = j * batch_size
            device = next(ori_model.parameters()).device
            batch_inputs = inputs[start:start + batch_size].to(device)
            batch_labels = labels[start:start + batch_size].to(device)

            output = model(batch_inputs, model.parameters)
            loss = loss_fn(output, batch_labels).sum()
            gradients = torch.autograd.grad(loss, model.parameters.values(),
                                            retain_graph=True, create_graph=True, only_inputs=True)

            # Functional SGD step: keeps the graph back to `inputs`.
            model.parameters = OrderedDict((name, param - lr * grad_part)
                                           for ((name, param), grad_part)
                                           in zip(model.parameters.items(), gradients))

    bn = []
    for module in model.net.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            bn_params_dict = module.state_dict()
            bn.append([bn_params_dict['running_mean'], bn_params_dict['running_var']])

    update = [param_origin - param
              for (_, param), (_, param_origin)
              in zip(model.parameters.items(), initial_params)]
    return update, bn


def reconstruction_costs(gradients, input_gradient, cost_fn='l2', indices='def', weights='equal'):
    """Score how closely candidate gradients match an observed gradient.

    Args:
        gradients: Sequence of candidate gradients, each a sequence of tensors
            aligned with ``input_gradient``.
        input_gradient: The observed gradient (sequence of tensors).
        cost_fn: One of 'l2', 'l1', 'max', 'sim', 'simlocal',
            'compressed<ratio>' or 'sim_cmpr<ratio>'. Any other value adds no
            cost.
        indices: Either an explicit list of layer indices or a named selector:
            'def', 'batch', 'topk-1', 'topk-2', 'top10', 'top50', 'first',
            'first4', 'first5', 'first10', 'first50', 'last5', 'last10',
            'last50', 'head', 'middle', 'tail', 'NoBN'. 'topk-2' picks the four
            largest-norm layers of each candidate gradient separately.
        weights: 'linear' or 'exp' for decaying per-layer weights; anything
            else weights all layers equally.

    Returns:
        The cost averaged over ``gradients``.

    Raises:
        ValueError: If ``indices`` is not a list or a known selector.
    """
    num_layers = len(input_gradient)

    def input_topk(k):
        # Indices of the k observed layers with the largest norm.
        return torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), k)[1]

    def magnitude_mask(flat, k):
        # Keep the entries whose magnitude reaches the k-th largest magnitude.
        threshold = torch.min(torch.topk(torch.abs(flat), k, 0, largest=True, sorted=False)[0])
        return torch.ge(torch.abs(flat), threshold)

    select_per_trial = False
    if isinstance(indices, list):
        pass
    elif indices == 'def':
        indices = torch.arange(num_layers)
    elif indices == 'batch':
        indices = torch.randperm(num_layers)[:8]
    elif indices == 'topk-1':
        indices = input_topk(4)
    elif indices == 'topk-2':
        # Resolved inside the loop from each candidate gradient.
        select_per_trial = True
    elif indices == 'top10':
        indices = input_topk(10)
    elif indices == 'top50':
        indices = input_topk(50)
    elif indices in ['first', 'first4']:
        indices = torch.arange(0, 4)
    elif indices == 'first5':
        indices = torch.arange(0, 5)
    elif indices == 'first10':
        indices = torch.arange(0, 10)
    elif indices == 'first50':
        indices = torch.arange(0, 50)
    elif indices == 'last5':
        indices = torch.arange(num_layers)[-5:]
    elif indices == 'last10':
        indices = torch.arange(num_layers)[-10:]
    elif indices == 'last50':
        indices = torch.arange(num_layers)[-50:]
    elif indices == 'head':
        indices = torch.arange(0, int(0.33 * num_layers))
    elif indices == 'middle':
        indices = torch.arange(int(0.33 * num_layers), int(0.67 * num_layers))
    elif indices == 'tail':
        indices = torch.arange(int(0.67 * num_layers), num_layers)
    elif indices == 'NoBN':
        indices = torch.arange(num_layers)
        # Hard-coded layer positions to exclude.
        to_remove = [1, 5, 9, 12, 16, 19, 23, 26]
        mask = torch.zeros(num_layers, dtype=torch.bool)
        mask[to_remove] = True
        indices = indices.masked_select(~mask)
    else:
        raise ValueError()

    ex = input_gradient[0]
    if weights == 'linear':
        weights = torch.arange(num_layers, 0, -1, dtype=ex.dtype, device=ex.device) / num_layers
    elif weights == 'exp':
        weights = torch.arange(num_layers, 0, -1, dtype=ex.dtype, device=ex.device)
        weights = weights.softmax(dim=0)
        weights = weights / weights[0]
    else:
        weights = ex.new_ones(num_layers)

    # Only these variants accumulate `pnorm` and need the global cosine
    # normalisation; 'simlocal' already normalises per layer.
    global_similarity = cost_fn == 'sim' or cost_fn.startswith('sim_cmpr')

    total_costs = 0
    for trial_gradient in gradients:
        pnorm = [0, 0]
        costs = 0
        if select_per_trial:
            _, trial_indices = torch.topk(torch.stack([p.norm().detach() for p in trial_gradient], dim=0), 4)
        else:
            trial_indices = indices

        for i in trial_indices:
            if cost_fn == 'l2':
                costs += ((trial_gradient[i] - input_gradient[i]).pow(2)).sum() * weights[i]
            elif cost_fn.startswith('compressed'):
                # Both sides are sparsified independently to their own top-k.
                ratio = float(cost_fn[10:])
                k = max(int(trial_gradient[i].flatten().shape[0] * (1 - ratio)), 1)

                trial_flatten = trial_gradient[i].flatten()
                trial_compressed = trial_flatten * magnitude_mask(trial_flatten, k)

                input_flatten = input_gradient[i].flatten()
                input_compressed = input_flatten * magnitude_mask(input_flatten, k)
                costs += ((trial_compressed - input_compressed).pow(2)).sum() * weights[i]
            elif cost_fn.startswith('sim_cmpr'):
                # The observed gradient's top-k mask is applied to both sides.
                ratio = float(cost_fn[8:])
                k = max(int(trial_gradient[i].flatten().shape[0] * (1 - ratio)), 1)

                input_flatten = input_gradient[i].flatten()
                input_mask = magnitude_mask(input_flatten, k)
                input_compressed = input_flatten * input_mask
                trial_compressed = trial_gradient[i].flatten() * input_mask

                costs -= (trial_compressed * input_compressed).sum() * weights[i]
                pnorm[0] += trial_compressed.pow(2).sum() * weights[i]
                pnorm[1] += input_compressed.pow(2).sum() * weights[i]
            elif cost_fn == 'l1':
                costs += ((trial_gradient[i] - input_gradient[i]).abs()).sum() * weights[i]
            elif cost_fn == 'max':
                costs += ((trial_gradient[i] - input_gradient[i]).abs()).max() * weights[i]
            elif cost_fn == 'sim':
                costs -= (trial_gradient[i] * input_gradient[i]).sum() * weights[i]
                pnorm[0] += trial_gradient[i].pow(2).sum() * weights[i]
                pnorm[1] += input_gradient[i].pow(2).sum() * weights[i]
            elif cost_fn == 'simlocal':
                costs += 1 - torch.nn.functional.cosine_similarity(trial_gradient[i].flatten(),
                                                                   input_gradient[i].flatten(),
                                                                   0, 1e-10) * weights[i]
        if global_similarity:
            costs = 1 + costs / pnorm[0].sqrt() / pnorm[1].sqrt()

        # Accumulate final costs
        total_costs += costs
    return total_costs / len(gradients)



from tqdm import tqdm
from pytorch_pretrained_biggan import convert_to_images, truncated_noise_sample

class NGReconstructor():
    """Derivative-free latent-code search against a class-conditional generator.

    A nevergrad optimizer proposes latent vectors. Each proposal is decoded by
    ``generator``, passed through the FL model, and scored by how closely the
    resulting parameter gradients match the observed ``input_gradient``.
    """

    def __init__(self, fl_model, generator, loss_fn, num_classes=1000, search_dim=(128,), strategy='CMA', budget=500, use_tanh=True, use_weight=False, defense_setting=None):
        """Store the attack configuration and build the nevergrad optimizer.

        ``search_dim[0]`` is the length of the latent vector, ``strategy`` is
        a key of ``ng.optimizers.registry`` and ``budget`` is both the budget
        handed to that optimizer and the number of rounds ``reconstruct``
        runs. With ``use_weight`` a 62-entry weight vector is created whose
        first 20 groups of three entries are halved from group to group.
        """
        self.generator = generator
        self.defense_setting = defense_setting
        self.use_tanh = use_tanh
        self.search_dim = search_dim
        self.budget = budget
        self.num_samples = 50  # candidates requested from the optimizer per round
        self.fl_setting = {'loss_fn': loss_fn, 'fl_model': fl_model, 'num_classes': num_classes}

        start_point = np.random.rand(search_dim[0])
        optimizer_cls = ng.optimizers.registry[strategy]
        self.optimizer = optimizer_cls(parametrization=ng.p.Array(init=start_point), budget=budget)

        if use_weight:
            layer_weight = np.ones(62,)
            for exponent in range(20):
                first = 3 * exponent
                layer_weight[first:first + 3] /= 2 ** exponent
            self.weight = layer_weight
        else:
            self.weight = None

    def evaluate_loss(self, z, labels, input_gradient):
        """Score one latent candidate ``z`` with the L2 gradient-matching loss."""
        return self.ng_loss(
            z=z,
            labels=labels,
            input_gradient=input_gradient,
            generator=self.generator,
            metric='l2',
            weight=self.weight,
            use_tanh=self.use_tanh,
            defense_setting=self.defense_setting,
            **self.fl_setting
        )

    def reconstruct(self, input_gradient,labels, use_pbar=True):
        """Run the ask/evaluate/tell search and decode the recommended latent.

        NOTE: the ``labels`` argument is ignored; the label is always
        re-derived from ``input_gradient`` through ``infer_label``.

        Returns ``(z_res, x_res, img_res, loss_res)``: the (optionally tanh'ed)
        latent, the generated batch resized to 256x256, the output of
        ``convert_to_images`` and the loss of the recommendation.
        """
        labels = self.infer_label(input_gradient)
        print('Inferred label: {}'.format(labels))

        settings = self.defense_setting
        if settings is not None:
            # Replace the configured defense parameters by estimates taken
            # from the observed gradient (this mutates the dict in place).
            if 'clipping' in settings:
                per_tensor_norms = torch.stack([torch.norm(g, 2) for g in input_gradient])
                settings['clipping'] = torch.norm(per_tensor_norms, 2).item()
                print('Estimated defense parameter: {}'.format(settings['clipping']))
            if 'compression' in settings:
                n_zero = sum(torch.sum(g == 0) for g in input_gradient)
                n_param = sum(torch.numel(g) for g in input_gradient)
                settings['compression'] = 100 * (n_zero/n_param).item()
                print('Estimated defense parameter: {}'.format(settings['compression']))

        device = input_gradient[0].device
        c = torch.nn.functional.one_hot(labels, num_classes=self.fl_setting['num_classes']).to(device)

        rounds = range(self.budget)
        pbar = tqdm(rounds) if use_pbar else rounds

        for r in pbar:
            # Ask for the whole population first, score it, then report back.
            candidates = [self.optimizer.ask() for _ in range(self.num_samples)]
            losses = [
                self.evaluate_loss(z=candidate.value, labels=labels, input_gradient=input_gradient)
                for candidate in candidates
            ]
            for candidate, value in zip(candidates, losses):
                self.optimizer.tell(candidate, value)

            mean_loss = np.mean(losses)
            if use_pbar:
                pbar.set_description("Loss {:.6}".format(mean_loss))
            else:
                print("Round {} - Loss {:.6}".format(r, mean_loss))

        recommendation = self.optimizer.provide_recommendation()
        z_res = torch.from_numpy(recommendation.value).unsqueeze(0).to(device)
        if self.use_tanh:
            z_res = z_res.tanh()
        loss_res = self.evaluate_loss(recommendation.value, labels, input_gradient)
        with torch.no_grad():
            x_res = self.generator(z_res.float(), c.float(), 1)
        x_res = nn.functional.interpolate(x_res, size=(256, 256), mode='area')
        img_res = convert_to_images(x_res.cpu())

        return z_res, x_res, img_res, loss_res

    @staticmethod
    def infer_label(input_gradient, num_inputs=1):
        """Return the ``num_inputs`` row indices of ``input_gradient[-2]`` with the smallest row sums."""
        row_sums = torch.sum(input_gradient[-2], dim=-1)
        ascending = torch.argsort(row_sums, dim=-1)
        return ascending[:num_inputs].detach().reshape((-1,)).requires_grad_(False)

    @staticmethod
    def ng_loss(z, # latent variable to be optimized
                loss_fn, # loss function for FL model
                input_gradient,
                labels,
                generator,
                fl_model,
                num_classes=1000,
                metric='l2',
                use_tanh=True,
                weight=None, # weight to be applied when calculating the gradient matching loss
                defense_setting=None # adaptive attack against defense
               ):
        """Gradient-matching loss of a single latent vector, as a Python float.

        The latent is decoded by ``generator``, resized to 256x256 and fed to
        ``fl_model``; the gradient of the classification loss with respect to
        the model parameters is compared to ``input_gradient`` using
        ``metric`` ('l2' or 'l1'), averaged over parameter tensors. Without
        ``use_tanh`` a KL-style penalty on the latent statistics is added.

        NOTE: ``loss_fn`` is accepted but not used; a fresh ``Classification``
        criterion is always applied.
        """
        device = input_gradient[0].device

        latent = torch.Tensor(z).unsqueeze(0).to(device)
        if use_tanh:
            latent = latent.tanh()

        onehot = torch.nn.functional.one_hot(labels, num_classes=num_classes).to(device)

        # The generator itself is not differentiated.
        with torch.no_grad():
            image = generator(latent, onehot.float(), 1)
        image = nn.functional.interpolate(image, size=(256, 256), mode='area')

        # compute the trial gradient
        criterion = Classification()
        target_loss, _, _ = criterion(fl_model(image), labels)
        trial_gradient = [
            grad.detach()
            for grad in torch.autograd.grad(target_loss, fl_model.parameters())
        ]

        # adaptive attack against defense ('noise' needs no adaptation here)
        if defense_setting is not None:
            if 'clipping' in defense_setting:
                trial_gradient = defense.gradient_clipping(trial_gradient, bound=defense_setting['clipping'])
            if 'compression' in defense_setting:
                trial_gradient = defense.gradient_compression(trial_gradient, percentage=defense_setting['compression'])
            if 'representation' in defense_setting: # for ResNet
                keep = input_gradient[-2][0]!=0
                trial_gradient[-2] = trial_gradient[-2] * keep

        if weight is None:
            weight = [1]*len(trial_gradient)
        else:
            assert len(weight) == len(trial_gradient)

        # weighted distance between trial and observed gradients
        dist = 0
        if metric == 'l2':
            for idx, trial in enumerate(trial_gradient):
                dist += ((trial - input_gradient[idx]).pow(2)).sum()*weight[idx]
        elif metric == 'l1':
            for idx, trial in enumerate(trial_gradient):
                dist += ((trial - input_gradient[idx]).abs()).sum()*weight[idx]
        dist /= len(trial_gradient)

        if not use_tanh:
            flat = z.squeeze() if False else latent.squeeze()
            variance = torch.std(flat, unbiased=False, axis=-1).pow(2)
            mean_sq = torch.mean(flat, axis=-1).pow(2)
            KLD = -0.5 * torch.sum(1 + torch.log(variance + 1e-10) - mean_sq - variance)
            dist += 0.1*KLD

        return dist.item()
class NGReconstructor_Reg():
    """Derivative-free reconstruction that searches directly in image space.

    The nevergrad candidates are the input batch itself, an array of shape
    ``(bs, *search_dim)``, which is fed straight to the FL model; no
    generator is involved. Candidates are scored by the cosine distance
    between the gradients they induce and the observed ``input_gradient``.
    """

    def __init__(self, fl_model,loss_fn, num_classes=1000, search_dim=(3,128,128), bs=1,strategy='CMA', budget=500, use_tanh=True, use_weight=False, defense_setting=None):
        """Store the attack configuration and build the nevergrad optimizer.

        ``strategy`` is a key of ``ng.optimizers.registry``; ``budget`` is
        both the budget handed to that optimizer and the number of rounds
        ``reconstruct`` runs. With ``use_weight`` a 62-entry weight vector is
        created whose first 20 groups of three entries are halved from group
        to group.
        """
        self.bs = bs
        self.budget = budget
        self.search_dim = search_dim
        self.use_tanh = use_tanh
        self.num_samples = 500  # candidates requested from the optimizer per round

        self.weight = None
        self.defense_setting = defense_setting
        self.model = fl_model

        # The search box is the image of [0, 1] under x -> (x - dm) / ds.
        # NOTE(review): dm/ds look like a dataset mean/std — confirm which one.
        dm = 0.13066373765468597
        ds = 0.30810782313346863
        parametrization = ng.p.Array(
            init=np.random.rand(bs, search_dim[0], search_dim[1], search_dim[2])
        ).set_bounds(-dm / ds, (1 - dm) / ds)
        self.optimizer = ng.optimizers.registry[strategy](parametrization=parametrization, budget=budget)

        self.fl_setting = {'loss_fn': loss_fn, 'fl_model': fl_model, 'num_classes': num_classes}

        if use_weight:
            self.weight = np.ones(62,)
            for i in range(0, 20):
                self.weight[3*i:3*(i+1)] /= 2**i

    def evaluate_loss(self, dummyx, labels, input_gradient):
        """Score one candidate batch ``dummyx`` with the cosine ('sim') metric."""
        return self.ng_loss(dummyx=dummyx, model=self.model, input_gradient=input_gradient, metric='sim',
                            labels=labels, weight=self.weight,
                            use_tanh=self.use_tanh, defense_setting=self.defense_setting, **self.fl_setting
                            )

    def _estimate_defense_parameters(self, input_gradient):
        """Overwrite configured defense parameters with estimates from ``input_gradient``.

        Mutates ``self.defense_setting`` in place: 'clipping' becomes the
        global L2 norm of the gradient, 'compression' the percentage of
        exactly-zero entries.
        """
        settings = self.defense_setting
        if settings is None:
            return
        if 'clipping' in settings:
            total_norm = torch.norm(torch.stack([torch.norm(g, 2) for g in input_gradient]), 2)
            settings['clipping'] = total_norm.item()
            print('Estimated defense parameter: {}'.format(settings['clipping']))
        if 'compression' in settings:
            n_zero = sum(torch.sum(g == 0) for g in input_gradient)
            n_param = sum(torch.numel(g) for g in input_gradient)
            settings['compression'] = 100 * (n_zero/n_param).item()
            print('Estimated defense parameter: {}'.format(settings['compression']))

    def reconstruct(self, input_gradient,labels, use_pbar=True):
        """Run the ask/evaluate/tell search and return ``(x_res, loss_res)``.

        ``labels`` is used as given (no label inference happens here).
        ``x_res`` is the recommended batch as a float tensor on the device of
        ``input_gradient[0]``, with tanh applied when ``use_tanh`` is set so
        that it matches what the loss was evaluated on; ``loss_res`` is the
        loss of that recommendation.
        """
        print('Inferred label: {}'.format(labels))

        self._estimate_defense_parameters(input_gradient)

        pbar = tqdm(range(self.budget)) if use_pbar else range(self.budget)

        for r in pbar:
            # Ask for the whole population first, score it, then report back.
            ng_data = [self.optimizer.ask() for _ in range(self.num_samples)]
            loss = [self.evaluate_loss(dummyx=candidate.value, labels=labels, input_gradient=input_gradient)
                    for candidate in ng_data]
            for candidate, value in zip(ng_data, loss):
                self.optimizer.tell(candidate, value)

            if use_pbar:
                pbar.set_description("Loss {:.6}".format(np.mean(loss)))
            else:
                print("Round {} - Loss {:.6}".format(r, np.mean(loss)))

        recommendation = self.optimizer.provide_recommendation()
        x_res = torch.from_numpy(recommendation.value).float().to(input_gradient[0].device)
        if self.use_tanh:
            # Same squashing as in ng_loss, so the returned image is the one scored.
            x_res = x_res.tanh()

        loss_res = self.evaluate_loss(recommendation.value, labels, input_gradient)

        return x_res, loss_res

    @staticmethod
    def infer_label(input_gradient, num_inputs=1):
        """Return the ``num_inputs`` row indices of ``input_gradient[-2]`` with the smallest row sums."""
        last_weight_min = torch.argsort(torch.sum(input_gradient[-2], dim=-1), dim=-1)[:num_inputs]
        labels = last_weight_min.detach().reshape((-1,)).requires_grad_(False)
        return labels

    @staticmethod
    def _gradient_distance(trial_gradient, input_gradient, weight, metric):
        """Weighted distance between two lists of gradient tensors.

        'l2' and 'l1' return the weighted sum of squared / absolute
        differences divided by the number of tensors. 'sim' returns one minus
        the weighted cosine similarity over all tensors jointly; the
        normalisation is applied once, after every tensor has been
        accumulated. Raises ``ValueError`` for any other metric.
        """
        if metric not in ('l2', 'l1', 'sim'):
            raise ValueError("Unknown metric '{}'; expected 'l2', 'l1' or 'sim'.".format(metric))

        dist = 0
        pnorm = [0, 0]
        for i, trial in enumerate(trial_gradient):
            target = input_gradient[i]
            if metric == 'l2':
                dist += ((trial - target).pow(2)).sum() * weight[i]
            elif metric == 'l1':
                dist += ((trial - target).abs()).sum() * weight[i]
            else:
                dist -= (trial * target).sum() * weight[i]
                pnorm[0] += trial.pow(2).sum() * weight[i]
                pnorm[1] += target.pow(2).sum() * weight[i]

        if metric == 'sim':
            return 1 + dist / pnorm[0].sqrt() / pnorm[1].sqrt()
        return dist / len(trial_gradient)

    @staticmethod
    def ng_loss(dummyx, # candidate input batch to be optimized
                loss_fn, # loss function for FL model
                input_gradient,
                labels,
                model,
                fl_model,
                num_classes=1000,
                metric='l2',
                use_tanh=True,
                weight=None, # weight to be applied when calculating the gradient matching loss
                defense_setting=None # adaptive attack against defense
               ):
        """Gradient-matching loss of one candidate batch, as a Python float.

        ``dummyx`` is converted to a float tensor (tanh-squashed when
        ``use_tanh`` is set), passed through ``model``, and the gradient of
        the classification loss with respect to ``fl_model.parameters()`` is
        compared to ``input_gradient`` with ``metric`` ('l2', 'l1' or 'sim').

        NOTE: ``loss_fn`` and ``num_classes`` are accepted but not used; a
        fresh ``Classification`` criterion is always applied. With
        ``use_tanh`` the candidate is confined to (-1, 1); pass
        ``use_tanh=False`` to evaluate the raw candidate.
        """
        dummyx = torch.Tensor(dummyx).float().to(input_gradient[0].device)
        if use_tanh:
            dummyx = dummyx.tanh()

        criterion = Classification()
        target_loss, _, _ = criterion(model(dummyx), labels)
        trial_gradient = torch.autograd.grad(target_loss, fl_model.parameters())
        trial_gradient = [grad.detach() for grad in trial_gradient]

        # adaptive attack against defense ('noise' needs no adaptation here)
        # NOTE(review): `defense` must be available at module level; it is not
        # among the imports visible at the top of this file — confirm.
        if defense_setting is not None:
            if 'clipping' in defense_setting:
                trial_gradient = defense.gradient_clipping(trial_gradient, bound=defense_setting['clipping'])
            if 'compression' in defense_setting:
                trial_gradient = defense.gradient_compression(trial_gradient, percentage=defense_setting['compression'])
            if 'representation' in defense_setting: # for ResNet
                mask = input_gradient[-2][0]!=0
                trial_gradient[-2] = trial_gradient[-2] * mask

        if weight is not None:
            assert len(weight) == len(trial_gradient)
        else:
            weight = [1]*len(trial_gradient)

        dist = NGReconstructor_Reg._gradient_distance(trial_gradient, input_gradient, weight, metric)

        return dist.item()
    
class NGReconstructor_B():
    """Derivative-free latent search for a batch of generator inputs.

    The nevergrad candidate is a 2-D array of shape
    ``(search_dim[0], search_dim[1])`` that is handed to ``generator``
    without adding a batch axis, so each row acts as one latent code.
    Candidates are scored by how closely the gradients of the generated
    batch match the observed ``input_gradient``.
    """

    def __init__(self, fl_model, generator, loss_fn, num_classes=1000, search_dim=(2,128), strategy='CMA', budget=500, use_tanh=True, use_weight=False, defense_setting=None):
        """Store the attack configuration and build the nevergrad optimizer.

        ``strategy`` is a key of ``ng.optimizers.registry``; ``budget`` is
        both the budget handed to that optimizer and the number of rounds
        ``reconstruct`` runs. With ``use_weight`` a 62-entry weight vector is
        created whose first 20 groups of three entries are halved from group
        to group.
        """
        self.generator = generator
        self.budget = budget
        self.search_dim = search_dim
        self.use_tanh = use_tanh
        self.num_samples = 50  # candidates requested from the optimizer per round
        self.weight = None
        self.defense_setting = defense_setting

        parametrization = ng.p.Array(init=np.random.rand(search_dim[0], search_dim[1]))
        self.optimizer = ng.optimizers.registry[strategy](parametrization=parametrization, budget=budget)

        self.fl_setting = {'loss_fn': loss_fn, 'fl_model': fl_model, 'num_classes': num_classes}

        if use_weight:
            self.weight = np.ones(62,)
            for i in range(0, 20):
                self.weight[3*i:3*(i+1)] /= 2**i

    def evaluate_loss(self, z, labels, input_gradient):
        """Score one latent candidate ``z`` with the L2 gradient-matching loss."""
        return self.ng_loss(z=z, input_gradient=input_gradient, metric='l2',
                            labels=labels, generator=self.generator, weight=self.weight,
                            use_tanh=self.use_tanh, defense_setting=self.defense_setting, **self.fl_setting
                            )

    def _estimate_defense_parameters(self, input_gradient):
        """Overwrite configured defense parameters with estimates from ``input_gradient``.

        Mutates ``self.defense_setting`` in place: 'clipping' becomes the
        global L2 norm of the gradient, 'compression' the percentage of
        exactly-zero entries.
        """
        settings = self.defense_setting
        if settings is None:
            return
        if 'clipping' in settings:
            total_norm = torch.norm(torch.stack([torch.norm(g, 2) for g in input_gradient]), 2)
            settings['clipping'] = total_norm.item()
            print('Estimated defense parameter: {}'.format(settings['clipping']))
        if 'compression' in settings:
            n_zero = sum(torch.sum(g == 0) for g in input_gradient)
            n_param = sum(torch.numel(g) for g in input_gradient)
            settings['compression'] = 100 * (n_zero/n_param).item()
            print('Estimated defense parameter: {}'.format(settings['compression']))

    def reconstruct(self, input_gradient,labels, use_pbar=True):
        """Run the ask/evaluate/tell search and decode the recommended latents.

        ``labels`` is used as given (no label inference happens here).
        Returns ``(z_res, x_res, img_res, loss_res)``: the (optionally
        tanh'ed) latents, the generated batch resized to 128x128, the output
        of ``convert_to_images`` and the loss of the recommendation.

        NOTE(review): the result is resized to 128x128 here while ``ng_loss``
        scores candidates at 256x256 — confirm this difference is intended.
        """
        print('Inferred label: {}'.format(labels))

        self._estimate_defense_parameters(input_gradient)

        device = input_gradient[0].device
        c = torch.nn.functional.one_hot(labels, num_classes=self.fl_setting['num_classes']).to(device)

        pbar = tqdm(range(self.budget)) if use_pbar else range(self.budget)

        for r in pbar:
            # Ask for the whole population first, score it, then report back.
            ng_data = [self.optimizer.ask() for _ in range(self.num_samples)]
            loss = [self.evaluate_loss(z=candidate.value, labels=labels, input_gradient=input_gradient)
                    for candidate in ng_data]
            for candidate, value in zip(ng_data, loss):
                self.optimizer.tell(candidate, value)

            if use_pbar:
                pbar.set_description("Loss {:.6}".format(np.mean(loss)))
            else:
                print("Round {} - Loss {:.6}".format(r, np.mean(loss)))

        recommendation = self.optimizer.provide_recommendation()
        z_res = torch.from_numpy(recommendation.value).to(device)
        if self.use_tanh:
            z_res = z_res.tanh()
        loss_res = self.evaluate_loss(recommendation.value, labels, input_gradient)
        with torch.no_grad():
            x_res = self.generator(z_res.float(), c.float(), 1)
        x_res = nn.functional.interpolate(x_res, size=(128, 128), mode='area')
        img_res = convert_to_images(x_res.cpu())

        return z_res, x_res, img_res, loss_res

    @staticmethod
    def infer_label(input_gradient, num_inputs=1):
        """Return the ``num_inputs`` row indices of ``input_gradient[-2]`` with the smallest row sums."""
        last_weight_min = torch.argsort(torch.sum(input_gradient[-2], dim=-1), dim=-1)[:num_inputs]
        labels = last_weight_min.detach().reshape((-1,)).requires_grad_(False)
        return labels

    @staticmethod
    def _gradient_distance(trial_gradient, input_gradient, weight, metric):
        """Weighted distance between two lists of gradient tensors.

        'l2' and 'l1' return the weighted sum of squared / absolute
        differences divided by the number of tensors. 'sim' returns one minus
        the weighted cosine similarity over all tensors jointly; the
        normalisation is applied once, after every tensor has been
        accumulated. Raises ``ValueError`` for any other metric.
        """
        if metric not in ('l2', 'l1', 'sim'):
            raise ValueError("Unknown metric '{}'; expected 'l2', 'l1' or 'sim'.".format(metric))

        dist = 0
        pnorm = [0, 0]
        for i, trial in enumerate(trial_gradient):
            target = input_gradient[i]
            if metric == 'l2':
                dist += ((trial - target).pow(2)).sum() * weight[i]
            elif metric == 'l1':
                dist += ((trial - target).abs()).sum() * weight[i]
            else:
                dist -= (trial * target).sum() * weight[i]
                pnorm[0] += trial.pow(2).sum() * weight[i]
                pnorm[1] += target.pow(2).sum() * weight[i]

        if metric == 'sim':
            return 1 + dist / pnorm[0].sqrt() / pnorm[1].sqrt()
        return dist / len(trial_gradient)

    @staticmethod
    def ng_loss(z, # latent variable to be optimized
                loss_fn, # loss function for FL model
                input_gradient,
                labels,
                generator,
                fl_model,
                num_classes=1000,
                metric='l2',
                use_tanh=True,
                weight=None, # weight to be applied when calculating the gradient matching loss
                defense_setting=None # adaptive attack against defense
               ):
        """Gradient-matching loss of one batch of latent codes, as a Python float.

        The latents are decoded by ``generator``, resized to 256x256 and fed
        to ``fl_model``; the gradient of the classification loss with respect
        to the model parameters is compared to ``input_gradient`` with
        ``metric`` ('l2', 'l1' or 'sim'). Without ``use_tanh`` a KL-style
        penalty on the latent statistics is added.

        NOTE: ``loss_fn`` is accepted but not used; a fresh ``Classification``
        criterion is always applied.
        """
        device = input_gradient[0].device

        z = torch.Tensor(z).to(device)
        if use_tanh:
            z = z.tanh()

        c = torch.nn.functional.one_hot(labels, num_classes=num_classes).to(device)

        # The generator itself is not differentiated.
        with torch.no_grad():
            x = generator(z, c.float(), 1)

        x = nn.functional.interpolate(x, size=(256, 256), mode='area')

        # compute the trial gradient
        criterion = Classification()
        target_loss, _, _ = criterion(fl_model(x), labels)
        trial_gradient = torch.autograd.grad(target_loss, fl_model.parameters())
        trial_gradient = [grad.detach() for grad in trial_gradient]

        # adaptive attack against defense ('noise' needs no adaptation here)
        # NOTE(review): `defense` must be available at module level; it is not
        # among the imports visible at the top of this file — confirm.
        if defense_setting is not None:
            if 'clipping' in defense_setting:
                trial_gradient = defense.gradient_clipping(trial_gradient, bound=defense_setting['clipping'])
            if 'compression' in defense_setting:
                trial_gradient = defense.gradient_compression(trial_gradient, percentage=defense_setting['compression'])
            if 'representation' in defense_setting: # for ResNet
                mask = input_gradient[-2][0]!=0
                trial_gradient[-2] = trial_gradient[-2] * mask

        if weight is not None:
            assert len(weight) == len(trial_gradient)
        else:
            weight = [1]*len(trial_gradient)

        dist = NGReconstructor_B._gradient_distance(trial_gradient, input_gradient, weight, metric)

        if not use_tanh:
            flat = z.squeeze()
            variance = torch.std(flat, unbiased=False, axis=-1).pow(2)
            KLD = -0.5 * torch.sum(1 + torch.log(variance + 1e-10) - torch.mean(flat, axis=-1).pow(2) - variance)
            dist += 0.1*KLD

        return dist.item()

class Loss:
    """Abstract base class for loss objects.

    Collects information about the 'higher-level' loss function: the
    evaluation of the loss itself and the evaluation of the metric that is
    actually targeted. Subclasses override ``__call__`` and ``metric``; both
    are expected to return a ``(value, name, format)`` triple.
    """

    def __init__(self):
        """Init."""
        pass

    def __call__(self, reference, argmin):
        """Return l(x, y) as ``(value, name, format)``; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def metric(self, reference, argmin):
        """Return the actually sought metric as ``(value, name, format)``; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
class Classification(Loss):
    """Cross-entropy loss for classification; the softmax is applied inside the criterion.

    The minimized criterion is cross entropy, the reported metric is the
    fraction of correctly predicted labels.
    """

    def __init__(self):
        """Create the mean-reduced torch cross-entropy criterion."""
        # All arguments are left at their torch defaults
        # (no class weights, ignore_index=-100, reduction='mean').
        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    def __call__(self, x=None, y=None):
        """Return ``(value, name, format)``, or just ``(name, format)`` when ``x`` is None."""
        label, fmt = 'CrossEntropy', '1.5f'
        if x is None:
            return label, fmt
        return self.loss_fn(x, y), label, fmt

    def metric(self, x=None, y=None):
        """Return ``(accuracy, name, format)``, or just ``(name, format)`` when ``x`` is None."""
        label, fmt = 'Accuracy', '6.2%'
        if x is None:
            return label, fmt
        # Predicted class is the arg-max over dim 1; accuracy is hits / batch size.
        hits = (x.data.argmax(dim=1) == y).sum().float()
        accuracy = hits / y.shape[0]
        return accuracy.detach(), label, fmt
